{
 "cells": [
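  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Trainer control\n",
    "Building on [fine-tuning with the trainer](./trainer_finetune_dataset2deploy.ipynb), the Qianfan Python SDK also provides flexible event callbacks and a resumable trainer. This notebook creates a training task, registers an `EventHandler`, and then calls `resume()` after the run hits an error."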
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install \"qianfan>=0.2.2\" -U"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0.2.2'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import qianfan\n",
    "qianfan.__version__"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites\n",
    "- Set the Qianfan security-authentication Access Key (AK) and Secret Key (SK)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "\n",
    "os.environ[\"QIANFAN_ACCESS_KEY\"] = \"your_ak\"\n",
    "os.environ[\"QIANFAN_SECRET_KEY\"] = \"your_sk\""
   ]
  },
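  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Import dependencies\n",
    "- `qianfan.trainer.consts`: constants used by the trainer\n",
    "- `qianfan.resources.console.consts`: field constants defined at the API layer\n",
    "- `qianfan.trainer.configs`: config data classes required by the trainer\n",
    "- `qianfan.trainer.LLMFinetune`: Trainer implementation for large language model fine-tuning tasks\n",
    "- `qianfan.trainer.Service`: service class representing a model service on the platform; available through `trainer.result`\n",
    "- `qianfan.dataset.Dataset`: Qianfan dataset class for importing and exporting datasets on the Qianfan platform, locally, or from third parties, and for operations such as data cleaning"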
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qianfan.trainer.consts import ActionState\n",
    "from qianfan.model.consts import ServiceType\n",
    "from qianfan.resources.console import consts as console_consts\n",
    "from qianfan.trainer.configs import TrainConfig\n",
    "from qianfan.model.configs import DeployConfig\n",
    "from qianfan.resources import QfMessages\n",
    "from qianfan.trainer import LLMFinetune, Service\n",
    "from qianfan.dataset import Dataset\n",
    "from typing import cast\n",
    "from qianfan.utils import enable_log\n",
    "import logging\n",
    "\n",
    "enable_log(logging.INFO)"
   ]
  },
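  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## EventHandler\n",
    "\n",
    "To monitor the state of each node at every stage of training, register an event callback. Each event's `action_state` reports how the current action is running, so the callback can trigger the corresponding business logic and insert custom behavior."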
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qianfan.model import Model\n",
    "from qianfan.dataset import Dataset\n",
    "\n",
    "# Load the dataset used for fine-tuning; this example uses a platform preset dataset:\n",
    "ds = Dataset.load(qianfan_dataset_id=15074, is_download_to_local=False)\n",
    "trainer = LLMFinetune(\n",
    "    train_type=\"ERNIE-Bot-turbo-0725\",\n",
    "    train_config=TrainConfig(\n",
    "        epoch=1,\n",
    "        learning_rate=0.0003,\n",
    "        max_seq_len=4096,\n",
    "        peft_type=\"LoRA\",\n",
    "    ),\n",
    "    dataset=ds,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qianfan.trainer.event import Event, EventHandler\n",
    "\n",
    "testset: Dataset = Dataset.load(data_file=\"./data/fin_cqa_test.jsonl\")\n",
    "# Define a custom EventHandler and implement its dispatch method\n",
    "class InferAfterSFT(EventHandler):\n",
    "    target_action: str\n",
    "    def __init__(self, target_action: str) -> None:\n",
    "        super().__init__()\n",
    "        self.target_action = target_action\n",
    "\n",
    "    def dispatch(self, event: Event) -> None:\n",
    "        print(\"receive: <\", event)\n",
    "        if self.target_action == event.action_id and event.action_state == ActionState.Done:\n",
    "            svc = cast(Service, event.data[\"service\"])\n",
    "            print(\"svc\", svc)\n",
    "            for row in testset.list():\n",
    "                msgs = QfMessages()\n",
    "                msgs.append(row[0][0][\"prompt\"], \"user\")\n",
    "                # Run inference on the fine-tuned service with this row's prompt\n",
    "                resp = svc.exec({\"messages\": msgs})\n",
    "                print(\"row infer result\", resp)\n",
    "            \n",
    "\n",
    "eh = InferAfterSFT(target_action=trainer.ppls[0].id)\n",
    "trainer.register_event_handler(eh)\n",
    "trainer.run()"
   ]
  },
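  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Resuming a task\n",
    "\n",
    "For failures that retries cannot cover, such as network interruptions or an unstable service, the SDK provides `resume()` to continue the training process. The example below resumes an `LLMFinetune` run after it was interrupted:"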
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[INFO] [12-07 21:54:28] data_source.py:1044 [t:139789057857344]: data releasing, keep rolling\n",
      "[INFO] [12-07 21:54:30] data_source.py:1044 [t:139789057857344]: data releasing, keep rolling\n",
      "[INFO] [12-07 21:54:33] data_source.py:1044 [t:139789057857344]: data releasing, keep rolling\n",
      "[INFO] [12-07 21:54:38] data_source.py:1044 [t:139789057857344]: data releasing, keep rolling\n",
      "[INFO] [12-07 21:54:41] data_source.py:1053 [t:139789057857344]: data releasing succeeded\n",
      "[INFO] [12-07 21:54:44] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 21:55:14] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 21:55:46] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[ERROR] [12-07 21:56:17] console_requestor.py:42 [t:139789057857344]: console api request failed with error code: 500002, err msg: auth failed, no access, please check the api doc\n"
     ]
    },
    {
     "ename": "APIError",
     "evalue": "api return error, code: 500002, msg: auth failed, no access",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n",
      "\u001b[0;31mAPIError\u001b[0m                                  Traceback (most recent call last)\n",
      "Cell \u001b[0;32mIn[6], line 1\u001b[0m\n",
      "\u001b[0;32m----> 1\u001b[0m trainer\u001b[39m.\u001b[39mrun()\n",
      "\n",
      "File \u001b[0;32m/data/anaconda3/lib/python3.11/site-packages/qianfan/trainer/finetune.py:173\u001b[0m, in \u001b[0;36mLLMFinetune.run\u001b[0;34m(self, **kwargs)\u001b[0m\n",
      "\u001b[1;32m    167\u001b[0m kwargs[\u001b[39m\"\u001b[39m\u001b[39mbackoff_factor\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m kwargs\u001b[39m.\u001b[39mget(\n",
      "\u001b[1;32m    168\u001b[0m     \u001b[39m\"\u001b[39m\u001b[39mbackoff_factor\u001b[39m\u001b[39m\"\u001b[39m, get_config()\u001b[39m.\u001b[39mTRAINER_STATUS_POLLING_BACKOFF_FACTOR\n",
      "\u001b[1;32m    169\u001b[0m )\n",
      "\u001b[1;32m    170\u001b[0m kwargs[\u001b[39m\"\u001b[39m\u001b[39mretry_count\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m kwargs\u001b[39m.\u001b[39mget(\n",
      "\u001b[1;32m    171\u001b[0m     \u001b[39m\"\u001b[39m\u001b[39mretry_count\u001b[39m\u001b[39m\"\u001b[39m, get_config()\u001b[39m.\u001b[39mTRAINER_STATUS_POLLING_RETRY_TIMES\n",
      "\u001b[1;32m    172\u001b[0m )\n",
      "\u001b[0;32m--> 173\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mresult[\u001b[39m0\u001b[39m] \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mppls[\u001b[39m0\u001b[39m]\u001b[39m.\u001b[39mexec(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n",
      "\u001b[1;32m    174\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\n",
      "\n",
      "File \u001b[0;32m/data/anaconda3/lib/python3.11/site-packages/qianfan/resources/requestor/console_requestor.py:46\u001b[0m, in \u001b[0;36mConsoleAPIRequestor._check_error\u001b[0;34m(self, body)\u001b[0m\n",
      "\u001b[1;32m     41\u001b[0m err_msg \u001b[39m=\u001b[39m body\u001b[39m.\u001b[39mget(\u001b[39m\"\u001b[39m\u001b[39merror_msg\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mno error message found in response body\u001b[39m\u001b[39m\"\u001b[39m)\n",
      "\u001b[1;32m     42\u001b[0m log_error(\n",
      "\u001b[1;32m     43\u001b[0m     \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mconsole api request failed with error code: \u001b[39m\u001b[39m{\u001b[39;00merror_code\u001b[39m}\u001b[39;00m\u001b[39m, err msg:\u001b[39m\u001b[39m\"\u001b[39m\n",
      "\u001b[1;32m     44\u001b[0m     \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m \u001b[39m\u001b[39m{\u001b[39;00merr_msg\u001b[39m}\u001b[39;00m\u001b[39m, please check the api doc\u001b[39m\u001b[39m\"\u001b[39m\n",
      "\u001b[1;32m     45\u001b[0m )\n",
      "\u001b[0;32m---> 46\u001b[0m \u001b[39mraise\u001b[39;00m errors\u001b[39m.\u001b[39mAPIError(error_code, err_msg)\n",
      "\n",
      "\u001b[0;31mAPIError\u001b[0m: api return error, code: 500002, msg: auth failed, no access"
     ]
    }
   ],
   "source": [
    "trainer.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[INFO] [12-07 22:00:58] actions.py:390 [t:139789057857344]: [train_action] resume from created job 17304/9077\n",
      "[INFO] [12-07 22:00:58] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:01:29] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:02:00] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:02:30] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:03:01] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:03:31] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:04:02] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:17:51] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:18:22] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:18:52] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:19:23] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:19:54] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:32:12] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:32:43] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:33:14] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:33:44] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:34:15] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:34:46] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:35:17] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:35:48] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:36:18] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: FINISH, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:36:18] actions.py:370 [t:139789057857344]: [train_action] fine-tune job has ended: 9077 with status: FINISH\n",
      "[INFO] [12-07 22:36:19] model.py:183 [t:139789057857344]: check train job: 17304/9077 status before publishing model\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[INFO] [12-07 22:36:20] model.py:199 [t:139789057857344]: model publishing keep polling, current status FINISH\n",
      "[INFO] [12-07 22:36:20] model.py:233 [t:139789057857344]: model ready to publish\n",
      "[INFO] [12-07 22:36:21] model.py:239 [t:139789057857344]: check model publish status: Creating\n",
      "[INFO] [12-07 22:36:51] model.py:239 [t:139789057857344]: check model publish status: Ready\n",
      "[INFO] [12-07 22:36:51] model.py:241 [t:139789057857344]: model 10248/12701 published successfully\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<qianfan.trainer.finetune.LLMFinetune at 0x7f22c4dd0210>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "trainer.resume()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "58f7cb64c3a06383b7f18d2a11305edccbad427293a2b4afa7abe8bfc810d4bb"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
